package com.googlecode.gaal.analysis.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import com.googlecode.gaal.analysis.api.Context;
import com.googlecode.gaal.data.api.Corpus;
import com.googlecode.gaal.data.api.IntSequence;
import com.googlecode.gaal.data.api.IntervalSet;
import com.googlecode.gaal.data.api.Multiset;
import com.googlecode.gaal.data.impl.TreeMultiset;
import com.googlecode.gaal.suffix.api.IntervalTree.Interval;
import com.googlecode.gaal.suffix.api.LinearizedSuffixTree;
import com.googlecode.gaal.suffix.api.LinearizedSuffixTree.BinaryInterval;
import com.googlecode.gaal.suffix.impl.LinearizedSuffixTreeImpl;

/**
 * Collects {@link Context}s from a corpus: each context is a pair of a left and a right sequence, associated with the
 * intervals ("fills") that were recorded for that pair while walking two interval trees.
 * <p>
 * Two trees are built in the constructor: {@code lst} over the corpus sequence and {@code lpt} over the reversed
 * corpus sequence. All extraction work happens eagerly in the constructor; afterwards the instance only serves
 * look-ups through {@link #iterator()} and {@link #getFill(NestedContext)}.
 *
 * @param <S> symbol type of the {@link Corpus} handed to the constructor
 */
public class NestedMaximalityContextExtractor<S> implements Iterable<Context> {
    /** Tree built over the corpus sequence. */
    private final LinearizedSuffixTree lst;
    /** Tree built over the reversed corpus sequence. */
    private final LinearizedSuffixTree lpt;
    /** Interval set produced by {@code LocalMaximumSetBuilder} for {@link #lst}. */
    private final IntervalSet<BinaryInterval> lstMaximalSet;
    /** Interval set produced by {@code SingletonBwtSetBuilder} for {@link #lst}. */
    private final IntervalSet<BinaryInterval> lstBwtSet;
    /** Interval set produced by {@code LocalMaximumSetBuilder} for {@link #lpt}. */
    private final IntervalSet<BinaryInterval> lptMaximalSet;
    /** For every context: the fill intervals recorded for it, each with an accumulated count. */
    private final Map<Context, Map<Interval, Integer>> envMap = new TreeMap<Context, Map<Interval, Integer>>();
    /** Restricts which intervals are accepted as fills and as left parts; see the traversal methods. */
    private final boolean maximalOnly;

    /**
     * Builds both trees and their interval sets, runs the extraction starting at the top interval of {@code lst},
     * and finally discards every context for which exactly one fill interval was recorded.
     *
     * @param corpus      source of the sequence and the alphabet size used to build the trees
     * @param maximalOnly when {@code true}, only the most recently seen candidate fill is kept during the descent
     *                    and only intervals contained in the {@code lpt} maximal set produce contexts
     */
    public NestedMaximalityContextExtractor(Corpus<S> corpus, boolean maximalOnly) {
        this.lst = new LinearizedSuffixTreeImpl(corpus.sequence(), corpus.alphabetSize());
        this.lpt = new LinearizedSuffixTreeImpl(corpus.sequence().reverse(), corpus.alphabetSize());
        lstMaximalSet = new LocalMaximumSetBuilder().buildIntervalSet(lst);
        lstBwtSet = new SingletonBwtSetBuilder().buildIntervalSet(lst);
        lptMaximalSet = new LocalMaximumSetBuilder().buildIntervalSet(lpt);
        this.maximalOnly = maximalOnly;
        traverseLeft(lst.top(), 0, new HashSet<Interval>());

        // Drop contexts with a single fill interval. Removing through the iterator of the value view deletes the
        // corresponding map entry in the same pass, so no temporary key list or second look-up is needed.
        Iterator<Map<Interval, Integer>> fills = envMap.values().iterator();
        while (fills.hasNext()) {
            if (fills.next().size() == 1) {
                fills.remove();
            }
        }
    }

    /**
     * Returns the fill intervals recorded for the given context together with their accumulated counts.
     *
     * @param tandem the context to look up
     * @return the backing (mutable) map for that context, or {@code null} if the context is not stored
     */
    public Map<Interval, Integer> getFill(NestedContext tandem) {
        return envMap.get(tandem);
    }

    /**
     * Recursive descent over {@code lst}. Terminal intervals are ignored.
     * <ul>
     * <li>An interval outside {@code lstMaximalSet} whose lcp exceeds its parent's lcp is added to the fill set
     * (after clearing the set when {@code maximalOnly} is on).</li>
     * <li>An interval inside {@code lstMaximalSet}, reached with a non-empty fill set and not contained in
     * {@code lstBwtSet}, is handed to {@link #extendToRight(Interval, Set)} and the descent stops there.</li>
     * <li>In every other case the descent continues into both children.</li>
     * </ul>
     *
     * @param interval  interval currently visited
     * @param parentLcp lcp value of the interval visited one level up (0 for the top interval)
     * @param fillSet   fill candidates collected on the path so far; owned and modified by this call
     */
    private void traverseLeft(BinaryInterval interval, int parentLcp, Set<Interval> fillSet) {
        if (!interval.isTerminal()) {
            int lcp = interval.lcp();
            if (!lstMaximalSet.contains(interval)) {
                if (lcp > parentLcp) {
                    if (maximalOnly) {
                        fillSet.clear();
                    }
                    fillSet.add(interval);
                }
            } else if (!fillSet.isEmpty() && !lstBwtSet.contains(interval)) {
                extendToRight(interval, fillSet);
                return;
            }
            // Each child gets its own copy so that additions made in one subtree do not leak into the other.
            traverseLeft(interval.leftChild(), lcp, new HashSet<Interval>(fillSet));
            traverseLeft(interval.rightChild(), lcp, new HashSet<Interval>(fillSet));
        }
    }

    /**
     * Looks up the reversed label of {@code interval} in {@code lpt} and, unless the interval found there is
     * contained in {@code lptMaximalSet}, starts {@link #collectLeftMaximal} from it.
     *
     * @param interval interval of {@code lst} at which the descent stopped
     * @param fillSet  fill candidates collected on the way to {@code interval}
     */
    private void extendToRight(Interval interval, Set<Interval> fillSet) {
        BinaryInterval lptInterval = lpt.search(interval.label().reverse());
        assert (lptInterval != null);
        assert (interval.label().size() == lptInterval.label().size());
        if (!lptMaximalSet.contains(lptInterval)) {
            collectLeftMaximal(lptInterval, lptInterval, interval, lptInterval.lcp(), fillSet);
        }
    }

    /**
     * Recursive descent over {@code lpt} below {@code parent}. Terminal intervals are ignored.
     * <ul>
     * <li>An interval inside {@code lptMaximalSet} is recorded via {@link #addTandem} and not descended into.</li>
     * <li>An interval outside that set is recorded only when {@code maximalOnly} is off and its lcp exceeds the
     * lcp of the interval one level up; the descent continues either way.</li>
     * </ul>
     *
     * @param interval    interval of {@code lpt} currently visited
     * @param parent      interval of {@code lpt} the descent started from; used to cut the edge label
     * @param lstInterval matching interval of {@code lst}
     * @param parentLcp   lcp value of the interval visited one level up
     * @param fillSet     fill candidates to combine with every recorded interval
     */
    private void collectLeftMaximal(BinaryInterval interval, Interval parent, Interval lstInterval, int parentLcp,
            Set<Interval> fillSet) {
        if (!interval.isTerminal()) {
            int lcp = interval.lcp();
            if (!lptMaximalSet.contains(interval)) {
                if (!maximalOnly) {
                    if (lcp > parentLcp) {
                        addTandem(interval, parent, lstInterval, fillSet);
                    }
                }
            } else {
                addTandem(interval, parent, lstInterval, fillSet);
                return;
            }
            collectLeftMaximal(interval.leftChild(), parent, lstInterval, lcp, fillSet);
            collectLeftMaximal(interval.rightChild(), parent, lstInterval, lcp, fillSet);
        }
    }

    /**
     * Records one context per fill: the left sequence is the reversed edge label of {@code leftInterval} relative
     * to {@code leftParent}, the right sequence is the edge label of {@code rightInterval} relative to the fill.
     * The count stored for the (context, fill) pair grows by the smaller of the two interval sizes.
     *
     * @param leftInterval  interval of {@code lpt} supplying the left sequence
     * @param leftParent    interval of {@code lpt} the left edge label is measured against
     * @param rightInterval interval of {@code lst} supplying the right sequence
     * @param fillSet       fills to pair with the two intervals
     */
    private void addTandem(Interval leftInterval, Interval leftParent, Interval rightInterval, Set<Interval> fillSet) {
        // Identical for every fill, so computed once.
        int count = Math.min(rightInterval.size(), leftInterval.size());
        for (Interval fill : fillSet) {
            NestedContext env = new NestedContext(leftInterval.edgeLabel(leftParent).reverse(),
                    rightInterval.edgeLabel(fill));
            Map<Interval, Integer> fillMap = envMap.get(env);
            if (fillMap == null) {
                fillMap = new HashMap<Interval, Integer>();
                envMap.put(env, fillMap);
            }
            Integer currCount = fillMap.get(fill);
            fillMap.put(fill, currCount == null ? count : currCount + count);
        }
    }

    /**
     * A context identified by its left and right sequence. Equality, hashing and ordering look at those two
     * sequences only; the fillers are read from the enclosing extractor on demand.
     */
    public class NestedContext implements Context, Comparable<NestedContext> {
        private final IntSequence left;
        private final IntSequence right;
        /** Lazily built result of {@link #fillerSet()}; {@code null} until first requested. */
        private Multiset<IntSequence> cachedFillers;

        /**
         * @param left  left sequence of the context
         * @param right right sequence of the context
         */
        protected NestedContext(IntSequence left, IntSequence right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public IntSequence leftSequence() {
            return left;
        }

        @Override
        public IntSequence rightSequence() {
            return right;
        }

        /**
         * Returns the labels of all fill intervals stored for this context, each label added once (the counts
         * kept in the extractor's map are not used here). The multiset is built on the first call and reused
         * afterwards. A context that has no entry in the extractor yields an empty multiset.
         */
        @Override
        public Multiset<IntSequence> fillerSet() {
            if (cachedFillers == null) {
                // Build into a local first: the cache must never be published half-initialised, otherwise a
                // failure on the first call would make later calls return a silently empty result.
                Multiset<IntSequence> labels = new TreeMultiset<IntSequence>();
                Map<Interval, Integer> fills = envMap.get(this);
                if (fills != null) {
                    for (Interval fill : fills.keySet()) {
                        labels.add(fill.label());
                    }
                }
                cachedFillers = labels;
            }
            return cachedFillers;
        }

        @Override
        public int fillerSetSize() {
            return fillerSet().size();
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + ((left == null) ? 0 : left.hashCode());
            result = prime * result + ((right == null) ? 0 : right.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (getClass() != obj.getClass()) {
                return false;
            }
            Context other = (Context) obj;
            if (left == null) {
                if (other.leftSequence() != null) {
                    return false;
                }
            } else if (!left.equals(other.leftSequence())) {
                return false;
            }
            if (right == null) {
                if (other.rightSequence() != null) {
                    return false;
                }
            } else if (!right.equals(other.rightSequence())) {
                return false;
            }
            return true;
        }

        /**
         * Orders contexts by left sequence first and by right sequence on ties.
         */
        @Override
        public int compareTo(NestedContext other) {
            int leftCompare = left.compareTo(other.left);
            if (leftCompare != 0) {
                return leftCompare;
            }
            // Only needed when the left sequences are equal.
            return right.compareTo(other.right);
        }
    }

    /**
     * Iterates over the stored contexts in the key order of the backing sorted map. The iterator is the key-set
     * iterator of that map, so calling {@code remove()} on it also removes the context from this extractor.
     */
    @Override
    public Iterator<Context> iterator() {
        return envMap.keySet().iterator();
    }
}
